"""Credential manager."""

import argparse
import functools
import sys
import textwrap
import typing

from cki_lib import logger
from cki_lib import misc

from . import all_tokens  # noqa: F401, pylint: disable=unused-import
from . import secrets
from . import token
from . import utils

# Module-level logger; Manager.status emits its per-group debug details through it.
LOGGER = logger.get_logger('cki_tools.credentials.manager')


class Manager:
    """Credential management."""

    def __init__(
        self,
        token_name: str,
        token_type_prefix: str,
        force: bool,
    ) -> None:
        """Credential management."""
        self.token_name = token_name
        self.token_type_prefix = token_type_prefix
        self.force = force

    @functools.cached_property
    def active_token_groups(self) -> set[str]:
        """Return active managed token group names (without versions)."""
        return {k.split('/')[0] for k in self.active_tokens.keys()}

    @functools.cached_property
    def active_tokens(self) -> dict[str, typing.Any]:
        """Return active managed tokens."""
        return {
            k: v for k, v in self.managed_tokens.items()
            if misc.get_nested_key(v, 'meta/active')
        }

    @functools.cached_property
    def inactive_tokens(self) -> dict[str, typing.Any]:
        """Return inactive managed tokens."""
        return {
            k: v for k, v in self.managed_tokens.items()
            if not misc.get_nested_key(v, 'meta/active')
        }

    @functools.cached_property
    def managed_tokens(self) -> dict[str, typing.Any]:
        """Return managed tokens."""
        return {
            k: v for k, v in secrets.read_secrets_file().items()
            if (not self.token_name or k == self.token_name or k.startswith(f'{self.token_name}/'))
            and misc.get_nested_key(v, 'meta/token_type') in token.TOKEN_REGISTRY
            and misc.get_nested_key(v, 'meta/token_type').startswith(self.token_type_prefix)
        }

    def create(self) -> int:
        """Create tokens."""
        if not self.token_name:
            print('create requires a single token')
            return 1

        _, token_version = utils.split_token_name(self.token_name)
        token.get_token(self.token_name, self.force).create(token_version)
        return 0

    def destroy(self) -> int:
        """Destroy tokens."""
        if not self.token_name:
            print('destroy requires a single token')
            return 1

        _, token_version = utils.split_token_name(self.token_name)
        token.get_token(self.token_name, self.force).destroy(token_version)
        return 0

    def rotate(self) -> int:
        """Rotate tokens."""
        if not self.token_name:
            print('rotate requires a single token')
            return 1

        _, token_version = utils.split_token_name(self.token_name)
        token.get_token(self.token_name, self.force).rotate(token_version)
        return 0

    def update(self) -> int:
        """Update the secret meta information about tokens."""
        if not self.active_tokens:
            print('no matching tokens found')
            return 1

        for token_name in self.active_tokens.keys():
            _, token_version = utils.split_token_name(token_name)
            token.get_token(token_name, self.force).update(token_version)
        return 0

    def validate(self) -> int:
        """Check validity of the token."""
        if not self.active_tokens:
            print('no matching tokens found')
            return 1

        for token_name in self.active_tokens.keys():
            _, token_version = utils.split_token_name(token_name)
            token.get_token(token_name, self.force).validate(token_version)
        return 0

    def purge(self) -> int:
        """Purge inactive tokens."""
        if not self.inactive_tokens:
            print('no matching tokens found')
            return 1

        for token_name in self.inactive_tokens.keys():
            _, token_version = utils.split_token_name(token_name)
            token.get_token(token_name, self.force).purge(token_version)
        return 0

    def status(self) -> int:
        """Print statistics about tokens."""
        if not self.active_tokens:
            print('no matching tokens found')
            return 1

        needs_rotate_all, needs_prepare_all, needs_clean_all = set(), set(), set()
        for token_group in self.active_token_groups:
            token_instance = token.get_token(token_group, self.force)
            needs_rotate = token_instance.check_needs_rotate()
            needs_prepare = token_instance.check_needs_prepare()
            needs_clean = token_instance.check_needs_clean()
            if needs_rotate or needs_prepare or needs_clean:
                LOGGER.debug('%s: needs_rotate=%s needs_prepare=%s needs_clean=%s',
                             token_group, needs_rotate, needs_prepare, needs_clean)
            needs_rotate_all |= {token_group} if needs_rotate else set()
            needs_prepare_all |= {token_group} if needs_prepare else set()
            needs_clean_all |= {token_group} if needs_clean else set()
        # by default, only show tokens that need to be prepared if they also need to be rotated
        if not self.force:
            needs_prepare_all &= needs_rotate_all
        print(f'Tokens that need to be rotated: {list(needs_rotate_all)}')
        print(f'Tokens that need to be prepared: {list(needs_prepare_all)}')
        print(f'Tokens that need to be cleaned: {list(needs_clean_all)}')
        return 0

    def _confirm_affected_tokens(self) -> int:
        """
        Print how many tokens will be affected and asks for confirmation.

        If the user replies with "y" or "yes", returns 0. Otherwise, returns 1.
        """
        msg = (
            f"You are about to update {len(self.active_token_groups)} tokens:\n"
            f"{textwrap.indent('\n'.join(self.active_token_groups), prefix='    - ')}\n"
            "Are you sure you want to proceed?"
        )
        if not self.token_name and not utils.confirm(msg):
            print("Interrupting!")
            return 1

        return 0

    def prepare(self) -> int:
        """Prepare tokens."""
        if ret := self._confirm_affected_tokens():
            return ret

        for token_group in self.active_token_groups:
            token.get_token(token_group, self.force).prepare()
        return 0

    def switch(self) -> int:
        """Switch the deployed tokens."""
        if ret := self._confirm_affected_tokens():
            return ret

        for token_group in self.active_token_groups:
            token.get_token(token_group, self.force).switch()
        return 0

    def clean(self) -> int:
        """Clean tokens."""
        if ret := self._confirm_affected_tokens():
            return ret

        for token_group in self.active_token_groups:
            token.get_token(token_group, self.force).clean()
        return 0


def main(args: list[str] | None = None) -> int:
    """Run main loop."""
    parser = argparse.ArgumentParser(description='Manage the lifecycle of secrets')
    parser.add_argument('action', choices=[
        'create', 'destroy', 'rotate',
        'update', 'validate', 'purge',
        'status', 'prepare', 'switch', 'clean',
    ], help='What to do')
    parser.add_argument('--token-name', help='Single token name')
    parser.add_argument('--token-type-prefix', default='', help='Filter by token type prefix')
    parser.add_argument('--force', action='store_true',
                        help='Force token rotation even if new enough')
    parsed_args = parser.parse_args(args)

    return typing.cast(int, getattr(Manager(
        parsed_args.token_name,
        parsed_args.token_type_prefix,
        parsed_args.force,
    ), parsed_args.action)())


if __name__ == '__main__':
    # Run as a script: the action's return code becomes the process exit status.
    sys.exit(main())
